#!/usr/bin/python

################################################################
### Align transcriptions with lattices.
################################################################

import sys,os,re,glob,math,glob,signal,traceback,sqlite3
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from scipy.ndimage import interpolation,measurements
from pylab import *
from optparse import OptionParser
from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,die,docproc
from scipy import stats
from ocrolib import dbhelper,ocroio
from ocrolib import ligatures,rect_union,Record
import multiprocessing
import ocrofst
from ocrofst import fstutils

# exit cleanly on Ctrl-C instead of dumping a traceback
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

import argparse
parser = argparse.ArgumentParser(description = """
%prog [-s gt] [-p] [args] *.gt.txt
""")

# character extraction options
parser.add_argument("-x","--extract",help="extract characters",default=None)
parser.add_argument("-X","--noextract",help="don't actually write files",action="store_true")
parser.add_argument("-N","--normalize",help="size normalize characters",action="store_true")
parser.add_argument("-B","--beam",help="size of beam",type=int,default=100)
parser.add_argument("-T","--confusions",help="confusion table (lattice output cost)",default=None)
parser.add_argument("-r","--rejects",help="save rejects as well",action="store_true")
parser.add_argument("-F","--usefst",action="store_true",help="use fst instead of lattice file")

# ground truth sources (one of -f/-g/-p/-l must be given)
parser.add_argument("-f","--filelist",help="list of images and langmods",default=None)
parser.add_argument("-g","--gt",help="extension for ground truth",default=None)
parser.add_argument("-p","--pagegt",help="arguments are page ground truth",action="store_true")
parser.add_argument("--ppat",default=".*",help="file pattern to match with -p")
parser.add_argument("-l","--langmod",help="language model",default=None)

# output options
parser.add_argument("-s","--suffix",help="output suffix for writing result",default=None)
parser.add_argument("-O","--overwrite",help="overwrite outputs",action="store_true")

# line acceptance thresholds
parser.add_argument("-P","--perc",help="percentile for reporting statistics",type=float,default=90.0)
parser.add_argument("-M","--maxperc",help="maximum cost at percentile",type=float,default=5.0)
parser.add_argument("-A","--maxavg",help="maximum average cost",type=float,default=5.0)
parser.add_argument("--minlength",help="minimum length for alignment",type=int,default=10)

# alignment costs
parser.add_argument("-c","--mmcost",help="mismatch cost",type=float,default=3.0)
parser.add_argument('-R','--rcost',help="single character mismatch threshold",type=float,default=8.0)
parser.add_argument('--lcost2',help="ligature cost",type=float,default=2.0)
parser.add_argument('--lcost3',help="ligature cost",type=float,default=3.0)
parser.add_argument("-C","--edcost",help="edit cost",type=float,default=9.0)
# BUGFIX: the following four options lacked type= declarations, so values
# given on the command line arrived as strings and broke the numeric
# comparisons/arithmetic they are used in (min(), lig>=maxlig, subplot()).
parser.add_argument("--trimcost",help="cost for removing characters at beginning/end",type=float,default=2.0)
parser.add_argument("--maxtrim",help="maximum number of chars permitted to be trimmed",type=int,default=1)
parser.add_argument("--maxlig",help="maximum number of ligatures on line",type=int,default=2)

parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),help="number of parallel processes to use")

# debugging / display options
parser.add_argument("-D","--Display",help="display",action="store_true")
parser.add_argument("--debug_line",action="store_true")
parser.add_argument("--debug_rawalign",action="store_true")
parser.add_argument("--debug_aligned",action="store_true")
parser.add_argument("--debug_select",action="store_true")
parser.add_argument('--dgrid',type=int,default=8,help="grid size for display")
parser.add_argument('--nlines',default=999999999,type=int,help="max # lines to process (for testing)")

parser.add_argument("args",default=[],nargs='*',help="input lines")
args = parser.parse_args()

# Confusion table: each entry is (recognizer_output, ground_truth, cost),
# loaded from the --confusions file (three whitespace-separated columns).

confusions = []

if args.confusions is not None:
    with open(args.confusions) as stream:
        for line in stream:
            fields = line.split()
            assert len(fields)==3
            confusions.append((fields[0],fields[1],float(fields[2])))

class LigTable:
    """Bidirectional string <-> integer code table used only inside the
    aligner for encoding ligatures; nothing is persisted in this encoding."""
    def __init__(self):
        self.ncodes = 0
        self.code2chr = {}
        self.chr2code = {}
        # pre-assign codes 0..127 so the code of a plain ASCII character
        # is simply its ordinal
        for i in range(128):
            self.ord(chr(i))
    def ord(self,c):
        """Return the integer code for string c, allocating a fresh code
        the first time c is seen."""
        existing = self.chr2code.get(c)
        if existing is not None:
            return existing
        code = self.ncodes
        self.ncodes += 1
        self.code2chr[code] = c
        self.chr2code[c] = code
        return code
    def chr(self,c):
        """Inverse of ord: return the string associated with code c."""
        return self.code2chr[c]

ltable = LigTable()  # global ligature code table shared by aligner functions

print repr(args.extract)  # NOTE(review): looks like leftover debug output -- confirm it is wanted
if args.extract is not None:
    # refuse to append to an existing extraction database
    assert not os.path.exists(args.extract),"%s: already exists; please remove"%args.extract

class NullDb:
    """No-op extraction sink used when no --extract output was requested;
    implements the same context-manager/insert interface as DbWriter and
    Hdf5Writer but discards everything."""
    def __enter__(self):
        # BUGFIX: return self (was None), consistent with DbWriter/Hdf5Writer,
        # so "with NullDb() as db:" yields a usable object
        return self
    def __exit__(self,*args):
        pass
    def insert(self,*args,**kw):
        pass
    def insertAll(self,*args,**kw):
        pass

class DbWriter:
    """Context manager that writes extracted character records into an
    sqlite3 database; the "chars" table schema is created via
    dbhelper.charcolumns on first use."""
    def __init__(self,fname):
        # create the table only if the database file is new; the long
        # timeout is presumably for concurrent writers (see --parallel) --
        # TODO confirm
        if not os.path.exists(fname):
            db = sqlite3.connect(fname,timeout=600.0)
            dbhelper.charcolumns(db,"chars")
            db.commit()
        else:
            db = sqlite3.connect(fname,timeout=600.0)
        self.db = db
    def __enter__(self):
        return self
    def insert(self,image,cls,cost=0.0,count=0,file=None,lgeo=None,rel=None,bbox=None):
        # NOTE(review): the count parameter is accepted but ignored -- the row
        # is always stored with count=1; confirm whether that is intended.
        dbhelper.dbinsert(self.db,"chars",
                          image=dbhelper.image2blob(image),
                          cost=float(cost),
                          cls=cls,
                          count=1,
                          file=file,
                          lgeo="%g %g %g"%lgeo if lgeo else None,
                          rel="%g %g %g"%rel if rel else None,
                          bbox="%d %d %d %d"%bbox if bbox else None)
    def insertAll(self,records):
        # insert a sequence of Record objects one by one
        for r in records:
            self.insert(image=r.image,cls=r.cls,cost=r.cost,file=r.file,lgeo=r.lgeo,rel=r.rel,bbox=r.bbox)
    def __exit__(self,*args):
        # commit and close on context exit
        self.db.commit()
        self.db.close()

def csnormalize(image,f=0.75):
    """Center and second-moment size-normalize a character image.

    The image is binarized at the midpoint between its min and max value;
    the centroid and principal second moment of the binary mask determine
    a translation and isotropic scale, which are applied to the original
    (gray) image with bilinear interpolation.  Blank images are returned
    unchanged."""
    # binarize at the mid-gray level
    threshold = mean([amax(image),amin(image)])
    binary = 1*(image>threshold)
    rows,cols = binary.shape
    [gx,gy] = mgrid[0:rows,0:cols]
    total = sum(binary)
    # nothing on: normalization is undefined, return the input as-is
    if total<1e-4: return image
    inv = 1.0/total
    # centroid of the binary mask
    cx = sum(gx*binary)*inv
    cy = sum(gy*binary)*inv
    # central second moments
    mxx = sum((gx-cx)**2*binary)*inv
    mxy = sum((gx-cx)*(gy-cy)*binary)*inv
    myy = sum((gy-cy)**2*binary)*inv
    # principal axis length via the largest eigenvalue
    evals,evecs = eigh(array([[mxx,mxy],[mxy,myy]]))
    radius = sqrt(amax(evals))
    scale = f*max(image.shape)/(4.0*radius)
    m = array([[1.0/scale,0],[0.0,1.0/scale]])
    rows,cols = image.shape
    center = array([cx,cy])
    offset = center-dot(m,array([rows/2,cols/2]))
    return interpolation.affine_transform(image,m,offset=offset,order=1)

def csnormalize1(image,size=32):
    """Rescale a character image to size x size and moment-normalize it;
    a completely uniform (blank) input yields an all-zero patch."""
    if amax(image)==amin(image):
        # uniform image: nothing to normalize
        return zeros((size,size))
    rescaled = docproc.isotropic_rescale(image,size)
    normalized = csnormalize(rescaled)
    # rescale values so the maximum is 1.0
    return normalized*1.0/amax(normalized)

import tables,fcntl

def uencode(s):
    """Pack a short unicode string (at most 4 characters) into a 64-bit
    integer, 16 bits per character, first character in the low bits."""
    assert len(s)<=4
    value = 0
    for ch in reversed(s):
        value = (value<<16)|ord(ch)
    return value

from tables import *

class Hdf5Writer:
    """Context manager that writes extracted character records into an HDF5
    file via pytables.  Arrays: 'patches' (N,size,size float32), 'classes'
    (uencode()d int64), 'costs' (float32), 'files' (variable-length string
    rows), 'rel' (N,3 float32), 'bboxes' (N,4 float32)."""
    def __init__(self,fname,mode="w",size=32):
        self.size = size
        h5 = tables.openFile(fname,mode)
        h5.createEArray(h5.root,'patches',Float32Atom(),shape=(0,size,size),filters=Filters(9))
        h5.createEArray(h5.root,'classes',Int64Atom(),shape=(0,),filters=tables.Filters(9))
        h5.createEArray(h5.root,'costs',Float32Atom(),shape=(0,),filters=tables.Filters(9))
        h5.createVLArray(h5.root,'files',StringAtom(120),filters=Filters(9))
        h5.createEArray(h5.root,'rel',Float32Atom(shape=(3,)),shape=(0,),filters=tables.Filters(9))
        h5.createEArray(h5.root,'bboxes',Float32Atom(shape=(4,)),shape=(0,),filters=Filters(9))
        self.h5 = h5
    def __enter__(self):
        return self
    def insert(self,image,cls,cost=0.0,count=0,file=None,lgeo=None,rel=None,bbox=None):
        # NOTE(review): count and lgeo are accepted but not stored -- confirm.
        assert image.shape==(self.size,self.size),"wrong image shape: %s"%(image.shape,)
        h5 = self.h5
        h5.root.patches.append([image])
        h5.root.classes.append([uencode(cls)])
        h5.root.costs.append([cost])
        # BUGFIX: append a one-element row, consistent with insertAll below;
        # appending the bare string stored it as a row of 1-char strings
        h5.root.files.append([file])
        h5.root.rel.append([array(rel,'f')])
        h5.root.bboxes.append([array(bbox,'f')])
    def insertAll(self,records):
        """Append a batch of Record objects in one call per array."""
        if len(records)<1: return
        h5 = self.h5
        h5.root.patches.append(array([r.image for r in records],'f'))
        h5.root.classes.append(array([uencode(r.cls) for r  in records],'int64'))
        h5.root.costs.append([r.cost for r in records])
        for r in records: h5.root.files.append([r.file])
        h5.root.rel.append([array(r.rel,'f') for r in records])
        h5.root.bboxes.append([array(r.bbox,'f') for r in records])
    def __exit__(self,*args):
        self.h5.close()
        self.h5 = None

import codecs
import openfst
from fstutils import explode_transcription,epsilon,space,sigma,add_between,optimize_openfst,openfst2ocrofst

def alignment_fst(line):
    """Build an FST that matches the transcription `line` against recognizer
    output, allowing insertions, deletions, mismatches, trimming at the line
    ends, explicit confusions, and 2-/3-character ligatures.

    Input labels are recognizer symbols; output labels are codes from the
    global `ltable`.  Costs come from the command line arguments."""
    fst = ocrofst.OcroFST()
    state = fst.AddState()
    fst.SetStart(state)
    # one state per position in the transcription
    states = [state]
    for i in range(len(line)):
        states.append(fst.AddState())

    # allow up to --maxtrim characters to be skipped at either end
    # (explicit floor division; was "/", which yields a float on Python 3)
    ntrim = min(len(line)//4,args.maxtrim)
    for i in range(ntrim):
        fst.AddArc(states[i],epsilon,ord("~"),args.trimcost,states[i+1])
    for i in range(len(line)-ntrim,len(line)):
        fst.AddArc(states[i],epsilon,ord("~"),args.trimcost,states[i+1])

    for i in range(len(line)):
        s = line[i]
        c = ord(s)
        start = states[i]
        next = states[i+1]

        # space is special (since we use separate skip/insertion arcs)

        # insertion of space
        fst.AddArc(next,space,space,0.0,next)
        # insertion of other character
        fst.AddArc(start,sigma,ord("~"),args.edcost,start)

        if s==" ":
            # space transition
            fst.AddArc(start,space,space,0.0,next)
            # skip space
            fst.AddArc(start,epsilon,space,0.0,next)
            continue

        if s in ["~","_"]:
            # common ground-truth indicators of errors, unrecognizable characters
            fst.AddArc(start,sigma,ord("~"),4.0,start) # also allow repetition with some cost
            fst.AddArc(start,sigma,ord("~"),0.0,next)
            continue

        # add character transition
        fst.AddArc(start,c,c,0.0,next)
        # mismatch between input and transcription
        fst.AddArc(start,sigma,ord("_"),args.rcost,next)
        # deletion in lattice
        fst.AddArc(start,epsilon,ord("~"),args.edcost,next)
        # insertion in lattice
        fst.AddArc(start,sigma,epsilon,args.edcost,next)

    # explicit transition for confusions; note that multi-character
    # outputs are always treated as ligatures

    for i in range(0,len(line)):
        for u,v,cost in confusions:
            if i+len(v)>len(line): continue
            if line[i:i+len(v)]==v:
                start = states[i]
                end = states[i+len(v)]
                for j in range(len(u)-1):
                    next = fst.AddState()
                    fst.AddArc(start,ord(u[j]),epsilon,0.0,next)
                    start = next
                fst.AddArc(start,ord(u[-1]),ltable.ord(v),cost,end)

    # explicit transitions for ligatures: use epsilon on input, use encoded ligatures on output,

    if args.lcost2<inf:
        for i in range(0,len(line)-2):
            s = line[i:i+2]
            if " " in s: continue
            start = states[i]
            catch = fst.AddState()
            fst.AddArc(start,ord("~"),ltable.ord(s),0.0,catch)
            fst.AddArc(start,sigma,ltable.ord(s),args.lcost2,catch)
            fst.AddArc(catch,ord(line[i+2]),ord(line[i+2]),0.0,states[i+3])

    if args.lcost3<inf:
        for i in range(0,len(line)-3):
            s = line[i:i+3]
            if " " in s: continue
            start = states[i]
            catch = fst.AddState()
            # BUGFIX: these two arcs previously targeted the stale variable
            # `next` (left over from the loops above) instead of `catch`,
            # so 3-character ligature paths were wired to the wrong state
            fst.AddArc(start,ord("~"),ltable.ord(s),0.0,catch)
            fst.AddArc(start,sigma,ltable.ord(s),args.lcost3,catch)
            fst.AddArc(catch,ord(line[i+3]),ord(line[i+3]),0.0,states[i+4])

    # trailing insertions in the lattice are allowed at edit cost
    last = states[-1]
    fst.AddArc(last,sigma,epsilon,args.edcost,last)
    fst.SetFinal(last,0.0)
    return fst

debug = 0  # global debug level; appears unused in this chunk -- TODO confirm

def align1(job):
    fname,gtfile = job
    # read the ground truth data and construct an FST
    if not os.path.exists(gtfile):
        print "*"+gtfile,": NOT FOUND"
        return
    with open(gtfile) as stream:
        gttext = stream.read()[:-1]
    gttext = re.sub("[\001-\011\013-\037]","~",gttext)


    if args.usefst:
        fst_file = ocrolib.fvariant(fname,"fst")
        if not os.path.exists(fst_file):
            print fst_file+": does not exist"
            return
        fst = ocrofst.OcroFST()
        fst.load(fst_file)
        fname = fst_file
    else:
        lattice_file = ocrolib.fvariant(fname,"lattice")
        if not os.path.exists(lattice_file):
            print lattice_file+": does not exist"
            return
        gr = Lattice()
        with open(lattice_file) as stream:
            gr.loadLattice(stream)
        fst = gr.getLatticeAsFST()
        fname = lattice_file

    bestcost = 1e38
    bestline = None
    bestfst = None
    best = None
    for i,line in enumerate(gttext.split("\n")):
        line = line.strip()
        line = re.sub(r'[~_ \t]+',' ',line)
        if len(line)<=1: continue
        gtfst = alignment_fst(line)
        # actually perform the alignment
        result = ocrofst.beam_search(fst,gtfst,100)
        if args.debug_line:
            print "line",i,len(result[0]),sum(result[4])
        if len(result[0])<=1: continue
        avg = sum(result[4])/len(result[0]+10.0)
        if avg>=bestcost: continue
        bestcost = avg
        bestline = line
        bestfst = gtfst
        best = result

    if len(bestline)<args.minlength:
        print "*"+fst_file,": SHORT/%d"%len(bestline)
        return

    if best is None:
        print "*"+fst_file,gtfile,": BEAM SEARCH FAILED"
        return

    v1,v2,ins,outs,costs = best

    if args.debug_rawalign:
        for i in range(len(v1)):
            print "raw-align %3d [%3d %3d] (%3d %3d) %6.2f"%(i,v1[i],v2[i],ins[i]>>16,ins[i]&0xffff,costs[i]),unichr(outs[i])

    sresult = []
    scosts = []
    segs = []

    n = len(ins)
    i = 1
    while i<n:
        if outs[i]==ord(" "):
            sresult.append(" ")
            scosts.append(costs[i])
            segs.append((0,0))
            i += 1
            continue
        # pick up ligatures indicated by the recognizer (multiple sequential 
        # output characters with the same input segments)
        j = i+1
        while j<n and ins[j]==ins[i]:
            j += 1
        cls = "".join([ltable.chr(c) for c in outs[i:j]])
        sresult.append(cls)
        scosts.append(sum(costs[i:j]))
        start = (ins[i]>>16)
        end = (ins[i]&0xffff)
        segs.append((start,end))
        i = j

    if len(sresult)<args.minlength:
        print "*"+fst_file,": OUTSHORT/%d"%len(sresult)
        return

    if args.debug_aligned:
        for i,row in enumerate(zip(sresult,scosts,segs)):
            print "aligned",i,row

    ml = max([len(x) for x in sresult])
    lig = sum([len(x)>1 for x in sresult])
    bad = sum([x=="~" for x in sresult])
    perc = stats.scoreatpercentile(costs,args.perc)
    avg = mean(costs)
    skip = (perc > args.maxperc or avg>args.maxavg or ml>3 or lig*1.0/len(sresult)>0.1 or bad*1.0/(3.0+len(sresult))>0.1 or lig>=args.maxlig)
    aligned = fstutils.implode_transcription(sresult,maxlig=100)
    if skip:
        print "%c%s %6.2f %6.2f:"%("*" if skip else " ",fname,perc,avg),""
    else:
        print "%c%s %6.2f %6.2f:"%("*" if skip else " ",fname,perc,avg),aligned

    if skip: return

    # read the raw segmentation
    rseg_file = ocrolib.fvariant(fname,"rseg")
    if not os.path.exists(rseg_file):
        print "*"+rseg_file,": NOT FOUND"
        return
    rseg = ocroio.read_line_segmentation(rseg_file)
    rseg_boxes = docproc.seg_boxes(rseg)

    # Now run through the segments and create a table that maps rseg
    # labels to the corresponding output element.

    assert len(sresult)==len(segs)
    assert len(scosts)==len(segs)

    bboxes = []

    rmap = zeros(amax(rseg)+1,'i')
    for i in range(1,len(segs)):
        start,end = segs[i]
        if start==0 or end==0: continue
        rmap[start:end+1] = i
        bboxes.append(rect_union(rseg_boxes[start:end+1]))
    assert rmap[0]==0

    cseg = zeros(rseg.shape,'i')
    for i in range(cseg.shape[0]):
        for j in range(cseg.shape[1]):
            cseg[i,j] = rmap[rseg[i,j]]

    assert len(segs)==len(sresult) 
    assert len(segs)==len(scosts)

    assert amin(cseg)==0,"amin(cseg)!=0 (%d,%d)"%(amin(cseg),amax(cseg))

    # first work on the list output; here, one list element should
    # correspond to each cost
    result = sresult
    assert len(result)==len(scosts),\
        "output length %d differs from cost length %d"%(len(result),len(costs))
    assert amax(cseg)<len(result),\
        "amax(cseg) %d not consistent with output length %d"%(amax(cseg),len(output_l))

    # if there are spaces at the end, trim them (since they will never have a corresponding cseg)

    while len(result)>0 and result[-1]==" ":
        result = result[:-1]
        costs = costs[:-1]

    perc = stats.scoreatpercentile(costs,50)
    avg = mean(costs)
    skip = (perc > 10.0 or avg>10.0)

    if args.suffix is not None:
        cseg_file = ocrolib.fvariant(fname,"cseg",args.suffix)
        if not args.overwrite:
            if os.path.exists(cseg_file): die("%s: already exists",cseg_file)
        ocrolib.write_line_segmentation(cseg_file,cseg)
        # aligned text line
        ocrolib.write_text(ocrolib.fvariant(fname,"txt","aligned"),aligned)
        # per character costs
        with ocrolib.fopen(fname,"costs",args.suffix,mode="w") as stream:
            for i in range(len(costs)):
                stream.write("%d %g\n"%(i,costs[i]))
        # true ground truth (best-matching line in the original transcription)
        ocrolib.write_text(ocrolib.fvariant(fname,"txt",args.suffix),bestline)

    result = []
    iraw = 0
    if args.Display:
        ion(); clf()

    line = ocrolib.read_image_gray(ocrolib.fvariant(fname,"png"))
    line = amax(line)-line
    lgeo = docproc.seg_geometry(rseg)
    grouper = ocrolib.Grouper()
    grouper.setSegmentation(rseg)
    results = []
    for i in range(grouper.length()):
        raw,mask = grouper.extractWithMask(line,i,dtype='B')
        # ion(); gray(); subplot(121); imshow(raw); subplot(122); imshow(mask); ginput(1,3)
        start = grouper.start(i)
        end = grouper.end(i)
        bbox = grouper.boundingBox(i)
        y0,x0,y1,x1 = bbox
        rel = docproc.rel_char_geom((y0,y1,x0,x1),lgeo)
        ry,rw,rh = rel
        assert rw>0 and rh>0
        if (start,end) in segs:
            index = segs.index((start,end))
            cls = sresult[index]
            cost = costs[index]
        else:
            cls = "~"
            cost = 0.0
        if cls!="~":
            if args.Display:
                iraw += 1
                subplot(args.dgrid,args.dgrid,iraw)
                gray(); imshow(raw)
                #gca().text(0.1,-0.1,"%d/%s"%(i,cls),transform=ax.transAxes,color='green')
                gca().set_frame_on(False)
                if cost>0.2:
                    ylabel("%d "%int(10*cost),color='red',size=10)
                xlabel("%-3s %d"%(cls,i),color='blue',size=10)
                xticks([])
                yticks([])
                print "output %3d %3d cls %-5s cost %6.2f    "%(start,end,cls,cost),
                print "y %6.2f w %6.2f h %6.2f"%(rel[0],rel[1],rel[2])
        if cls!="~" or args.rejects:
            if args.normalize: raw = csnormalize1(raw)
            results.append(Record(image=raw,
                                  cost=float(cost),
                                  cls=cls,
                                  count=1,
                                  file=fname,
                                  lgeo=lgeo,
                                  rel=rel,
                                  bbox=bbox))
    if args.Display: ginput(1,10000)
    return results

def safe_align1(job):
    """Pool-safe wrapper around align1: logs any failure and returns []
    instead of raising, so one bad line cannot kill the worker pool."""
    try:
        return align1(job)
    except Exception:
        # narrowed from a bare except so KeyboardInterrupt/SystemExit
        # still propagate out of worker processes
        traceback.print_exc()
        return []

jobs = []

if args.filelist:
    for line in open(args.filelist).readlines():
        image,lmod = line.strip().split(2)
        jobs.append((image,lmod))
elif args.pagegt:
    if len(args.args)==1 and os.path.isdir(args.args[0]):
        files = glob.glob(args[0]+"/*.gt.txt")
    else:
        files = args.args
    missing_pagedir = 0
    for arg in files:
        if not os.path.exists(arg):
            print "# %s: not found"%arg
            continue
        base,_ = ocrolib.allsplitext(arg)
        if not os.path.exists(base) or not os.path.isdir(base):
            # print "# %s: dir not found"
            # sys.stdout.flush()
            missing_pagedir += 1
            continue
        lines = glob.glob(base+"/??????.png")
        for line in lines:
            if not re.search(args.ppat,line): 
                continue
            jobs.append((line,arg))
    print "number of page files",len(files),"number missing line dirs",missing_pagedir
elif args.langmod:
    for arg in args.args:
        jobs.append((arg,args.langmod))
elif args.gt is not None:
    count = 0
    allfiles = []
    for arg in args.args:
        if os.path.isdir(arg) or os.path.isdir(arg+"/."):
            files = sorted(glob.glob(arg+"/????/??????.png"))
            print "adding",len(files),"files from",arg
            assert len(files)>0
            allfiles += files
        else:
            files = sorted(glob.glob(arg))
            assert len(files)>0
            allfiles += files
            count += len(files)
    print "added",count,"files directly"
    for arg in allfiles:
        path,ext = ocrolib.allsplitext(arg)
        p = path+args.gt
        if not os.path.exists(arg):
            print arg,"not found"
            continue
        if not os.path.exists(p): 
            print p,"not found"
            continue
        jobs.append((arg,p))
else:
    raise Exception("you need to specify what kind of groundtruth you want to align with (-f, -p, -l, -g)")

print "got",len(jobs),"jobs"
# honor --nlines for quick test runs
if len(jobs)>args.nlines: jobs = jobs[:args.nlines]

# choose the extraction backend from the -x file extension;
# NullDb silently discards records when no extraction was requested
extract_db = NullDb()
if args.extract is not None:
    if args.extract[-3:]==".db":
        extract_db = DbWriter(args.extract)
    elif args.extract[-3:]==".h5":
        extract_db = Hdf5Writer(args.extract)
        args.normalize = 1  # HDF5 patches are always stored size-normalized
    else:
        print "unknown database extension",args.extract
        sys.exit(1)

# main driver: run alignment either sequentially or in a process pool,
# inserting the extracted character records into the chosen backend
with extract_db as db:
    if args.parallel<2:
        # sequential path (easier to debug; exceptions propagate)
        for arg in jobs:
            results = align1(arg)
            if results is None: continue
            for r in results:
                extract_db.insert(image=r.image,cls=r.cls,cost=r.cost,count=1,file=r.file,lgeo=r.lgeo,rel=r.rel,bbox=r.bbox)
    else:
        # parallel path: safe_align1 never raises, so one bad line
        # cannot kill the pool
        # NOTE(review): the Pool is never close()d/join()ed -- confirm
        pool = Pool(processes=args.parallel)
        n = 0
        csize = 100  # jobs per chunk; progress is printed after each chunk
        for s in range(0,len(jobs),csize):
            for results in pool.imap_unordered(safe_align1,jobs[s:min(s+csize,len(jobs))]):
                n += 1
                if results is None: continue
                extract_db.insertAll(results)
            print "==========",s,len(jobs),"=========="
